from AL_Dataset import ActiveLearning_Framework
import torch
from torch.nn import Module
import torch.nn as nn
from torch.nn import functional as F
from torch.utils import data
from torch.utils.data import DataLoader
from tqdm import tqdm
import utils
import numpy as np
from sklearn.metrics import pairwise_distances
from sklearn import svm
import math
from matplotlib import pyplot as plt
import methods as opt
from temperature_scaling import ModelWithTemperature

'''
Our Method:
Sparse Coreset
'''
class SparseCoreset:
    """Sparse-coreset query strategy for active learning.

    One call to :meth:`query` runs, in order: temperature calibration on the
    validation set, inference on the unlabeled pool (pseudo-label
    distribution), gradient-embedding projection, per-sample variance
    computation, a sparse approximation solve (``methods.proximal_iht`` or
    ``methods.greedy``) and finally the update of the active-learning pools.
    """

    # Batch size used when streaming the unlabeled pool / validation set
    # through the model (independent of the training ``batch_size``).
    _EVAL_BATCH_SIZE = 1024
    # Settings understood by :meth:`query`.
    _PROJECTION_METHODS = ('gradient',)
    _OPTIMIZATION_METHODS = ('prox_iht', 'greedy')

    def __init__(self, model: Module, ALset: "ActiveLearning_Framework", num_queried_each_round, batch_size, device, num_class, projection_method, optimization_method, alpha, beta):
        """
        use SABAL as query strategy
        :param: model: (nn.Module) current model; its forward pass is unpacked as three
                       values and it must expose ``get_embedding_dim()``
        :param: ALset: (class: ActiveLearning_Framework)
        :param: num_queried_each_round: (int) number of queried data each AL iteration
        :param: batch_size: (int) dataloader batchsize (stored on the instance; the pool
                       loaders built here use ``_EVAL_BATCH_SIZE`` instead)
        :param: device: (torch.device) GPU/CPU
        :param: num_class: (int) number of image class in dataset
        :param: projection_method: (str) "gradient" - use gradient embedding
        :param: optimization_method: (str) "prox_iht" or "greedy" solver for the sparse approximation
        :param: alpha: (float) scaling factor for the variance term
        :param: beta: (float) scaling factor for the weight regularizer
        """
        self.model = model
        self.ALset = ALset
        self.num_queried = num_queried_each_round
        self.batch_size = batch_size
        self.device = device
        self.num_class = num_class
        labeled_pool = self.ALset.get_train_dataset_AL()
        unlabeled_pool = self.ALset.get_unlabeled_dataset_AL()
        self.unlabeled_pool = unlabeled_pool
        self.labeled_pool = labeled_pool
        self.labeled_length = len(labeled_pool)
        self.unlabeled_length = len(unlabeled_pool)
        # shuffle=False: row i of every per-sample tensor below must refer to
        # the i-th element of the unlabeled pool.
        self.dataloader = torch.utils.data.DataLoader(unlabeled_pool, batch_size=self._EVAL_BATCH_SIZE, shuffle=False)
        self.projection_method = projection_method
        # for projection part
        self.pseudo_labels = torch.zeros((len(unlabeled_pool), num_class), device=device)
        self.embedding = None
        self.expected_embedding = None
        self.idx_selected = []
        self.variance = torch.zeros(len(unlabeled_pool))
        # for optimization
        self.optimization_method = optimization_method
        self.alpha = alpha
        self.beta = beta
        # for calibration
        self.temperature = 1

    def query(self):
        """Select the next batch of samples and hand it to the AL framework.

        Side effects: updates ``self.temperature``, ``self.pseudo_labels``,
        ``self.embedding``, ``self.expected_embedding`` and ``self.variance``,
        calls ``ALset.Update_AL_Datapool`` with the selected pool indices and
        their weights, and stores a summary dict in ``ALset.optim_results``.

        :raises ValueError: if ``projection_method`` or ``optimization_method``
            is not one of the supported settings. The check runs before any
            model pass so a misconfiguration does not cost a full inference.
        """
        if self.projection_method not in self._PROJECTION_METHODS:
            raise ValueError(
                'unsupported projection_method {!r}; expected one of {}'.format(
                    self.projection_method, self._PROJECTION_METHODS))
        if self.optimization_method not in self._OPTIMIZATION_METHODS:
            raise ValueError(
                'unsupported optimization_method {!r}; expected one of {}'.format(
                    self.optimization_method, self._OPTIMIZATION_METHODS))

        self.calibration()      # to get more calibrated label distribution
        self.inference()

        # compute variance, and get projection
        self.embedding, self.expected_embedding = self.gradient_projection()
        self.compute_variance_gradient()

        # coreset approximation
        k = self.num_queried
        Phi = self.expected_embedding.T
        y = Phi.sum(dim=1).reshape([-1, 1]) / self.unlabeled_length
        y_norm = y.norm()
        # Out-of-place on purpose: Phi starts as a transposed *view* of
        # self.expected_embedding, so an in-place division would silently
        # rescale the stored embedding as well.
        Phi = Phi / k
        Phi = Phi / y_norm
        y = y / y_norm

        # Variances are normalised by the sum of the k largest ones.
        variance_idx = torch.argsort(self.variance, descending=True)[:k]
        variance = self.variance / torch.sum(self.variance[variance_idx])
        print( 'mean varance/largest varaince\n', torch.mean(self.variance)*k/torch.sum(self.variance[variance_idx]) )

        if self.optimization_method == 'prox_iht':
            w, supp = opt.proximal_iht(Phi, y, variance, k, self.alpha, self.beta, verbose=True, reg_type='one')
        else:  # 'greedy' (validated above)
            sigma = Phi.pow(2).sum(dim=0).pow(0.5)
            L = sigma.sum().item()
            w, supp = opt.greedy(Phi, y, variance, k, self.alpha, sigma, L, self.beta, verbose=True, reg_type='one')

        # optimization results
        print('optimization results:')
        print('{} items are selected; maximal coreset size k = {}.'.format(len(supp), k))
        weights_display_num = min(len(supp), 15)
        print('weights selected (first {} are displayed) are {} ...'.format(weights_display_num, w[supp[:weights_display_num]].reshape([-1])))
        w_normalize = w.reshape([-1])  # ideally no need to normalize
        print('weights after normalization (first {} are displayed) are {} ...'.format(weights_display_num, w_normalize[supp[:weights_display_num]]))
        w_min = w_normalize[supp].min().item()
        w_max = w_normalize[supp].max().item()
        mean_deviation = (w_normalize[supp] - 1).abs().mean().item()
        print('weights has minimum {}, maximum {}, and mean deviation {}'.format(w_min, w_max, mean_deviation))
        f, f1, f2 = opt.obj(y, Phi, variance, self.alpha, w, supp)
        print('training objective (f1 + alpha * f2) is {}, approximation loss (f1) is {}, selected variance loss (alpha * f2) is {}, selected original variance (f2) is {}'.format(f, f1, f2, f2 / (self.alpha+1e-30)))
        results = {'train_total_loss': f, 'train_approximation_loss': f1, 'train_variance_loss': f2,
                   'w_min': w_min, 'w_max': w_max, 'w_mean_deviation': mean_deviation, 'w': w.cpu().numpy(), 'w_normalize': w_normalize.cpu().numpy(),
                   'supp': supp, 'selected_coreset_size': len(supp), 'maximal_coreset_size': k,
                   'number_proj_train': Phi.shape[0], 'alpha': self.alpha, 'beta': self.beta
                  }

        # data query: map positions in the unlabeled pool back to dataset indices
        idx_selected = [self.ALset.unlabeled_idx[i] for i in supp]
        w_normalize = w_normalize[supp]
        self.ALset.Update_AL_Datapool(idx_selected, w_normalize)
        self.ALset.optim_results = results

    def calibration(self):
        """Fit a softmax temperature on the validation set and store it in ``self.temperature``."""
        print('Starting calibration...\n')
        valid_set = self.ALset.get_validation_dataset()
        validationloader = torch.utils.data.DataLoader(valid_set, batch_size=self._EVAL_BATCH_SIZE, shuffle=True)
        scaled_model = ModelWithTemperature(self.model)
        scaled_model.set_temperature(validationloader)
        self.temperature = scaled_model.temperature.item()

    def inference(self):
        """Fill ``self.pseudo_labels`` with temperature-scaled class probabilities for the unlabeled pool."""
        print('Starting inference...\n')
        self.model.eval()
        with torch.no_grad():
            pointer = 0
            for data in tqdm(self.dataloader):
                images = data[0].to(self.device)
                outputs, _, _ = self.model(images)
                probability = utils.output_softmax(outputs/self.temperature)
                stop = pointer + len(images)
                self.pseudo_labels[pointer:stop] = probability
                pointer = stop
        print('Finished inference!\n')

    @staticmethod
    def _batch_gradient_embedding(features, probs):
        """Build the per-class gradient embeddings of one batch.

        :param: features: (Tensor, [B, D]) feature vector of each sample
        :param: probs: (Tensor, [B, C]) softmax output of each sample
        :return: (Tensor, [B, C * D, C]) where, for sample n and hypothesised
                 label c, rows ``k*D:(k+1)*D`` of column c hold
                 ``features[n] * (1[k == c] - probs[n, k])``
        """
        batch, emb_dim = features.shape
        num_class = probs.shape[1]
        # coeff[n, k, c] = 1[k == c] - probs[n, k]
        identity = torch.eye(num_class, dtype=probs.dtype, device=probs.device)
        coeff = identity - probs.unsqueeze(2)
        # (B, C, 1, C) * (B, 1, D, 1) -> (B, C, D, C); flattening (C, D) puts
        # class block k at rows k*D:(k+1)*D.
        block = coeff.unsqueeze(2) * features.unsqueeze(1).unsqueeze(3)
        return block.reshape(batch, num_class * emb_dim, num_class)

    def gradient_projection(self):
        """
        if projection_method == "gradient", compute gradient embedding w.r.t. expected loss

        Requires ``self.pseudo_labels`` to be filled (see ``inference``).

        :return: tuple ``(embedding, expected_embedding)`` of CPU tensors with shapes
                 ``[N, D * C, C]`` and ``[N, D * C]``; ``expected_embedding[n]`` is
                 ``embedding[n] @ pseudo_labels[n]``
        """
        embDim = self.model.get_embedding_dim()
        pool_size = len(self.unlabeled_pool)
        embedding = torch.zeros([pool_size, embDim * self.num_class, self.num_class])
        expected_embedding = torch.zeros([pool_size, embDim * self.num_class])
        self.model.eval()
        with torch.no_grad():
            pointer = 0
            for data in tqdm(self.dataloader):
                images = data[0].to(self.device)
                outputs, _, out = self.model(images)
                out = out.detach().cpu()
                batchProbs = F.softmax(outputs, dim=1).detach().cpu()
                stop = pointer + len(images)
                embedding[pointer:stop] = self._batch_gradient_embedding(out, batchProbs)
                # One device->host copy per batch instead of one per sample.
                pseudo = self.pseudo_labels[pointer:stop].cpu()
                expected_embedding[pointer:stop] = torch.bmm(embedding[pointer:stop], pseudo.unsqueeze(2)).squeeze(2)
                pointer = stop
        return embedding, expected_embedding

    def compute_variance_gradient(self):
        """
        if projection_method == "gradient", compute the variance term

        For every sample i writes into ``self.variance[i]`` the pseudo-label
        weighted sum over classes c of
        ``|| expected_embedding[i] - embedding[i, :, c] ||^2``.
        Requires ``self.embedding`` and ``self.expected_embedding`` to be set.
        """
        # Single transfer; rows are processed one at a time so no extra
        # tensor of the full embedding size is materialised.
        probs = self.pseudo_labels.cpu()
        for i in range(self.unlabeled_length):
            if (i+1) % 10000 == 0:
                print('calculated variance ', i+1, ' times')
            deviation = self.expected_embedding[i].unsqueeze(1) - self.embedding[i]
            self.variance[i] = (probs[i] * deviation.pow(2).sum(dim=0)).sum()

        print('\n', 'variances', torch.max(self.variance), torch.mean(self.variance), torch.min(self.variance), '\n')
        print('Finished calculating variance!\n')

   